# Copyright (c) Anish Athalye. Released under GPLv3.

import time
from collections import OrderedDict
from functools import reduce

import numpy as np
import tensorflow.compat.v1 as tf
from PIL import Image

import vgg

# Switch TensorFlow to v1 behavior at import time: stylize() below builds
# explicit graphs and uses tf.placeholder / tf.Session via the compat.v1 API.
tf.disable_v2_behavior()

# Keys into the network dict returned by vgg.net_preloaded() whose activations
# are compared for the content loss. stylize() assigns blend weights to exactly
# these two names ("relu4_2" and "relu5_2"), so changing this tuple requires
# updating content_layers_weights there as well.
CONTENT_LAYERS = ("relu4_2", "relu5_2")
# Keys into the same network dict whose Gram matrices are compared for the
# style loss. Order matters: stylize() multiplies the per-layer weight by
# style_layer_weight_exp for each successive entry before normalizing.
STYLE_LAYERS = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")


def get_loss_vals(loss_store):
    """Evaluate every entry of *loss_store* and return the concrete values.

    Each value in *loss_store* must expose a zero-argument ``eval()`` method;
    it is called once per entry, in the mapping's iteration order.

    :param loss_store: mapping of loss name -> object with ``eval()``
    :return: ``OrderedDict`` of loss name -> evaluated value, in the same
        order as *loss_store*
    """
    evaluated = OrderedDict()
    for name, tensor in loss_store.items():
        evaluated[name] = tensor.eval()
    return evaluated


def print_progress(loss_vals):
    """Print one ``<name> loss: <value>`` line per entry of *loss_vals*.

    The label is right-aligned in a 13-character column and the value is
    rendered with the general (``g``) float format.

    :param loss_vals: mapping of loss name (str) -> numeric value
    """
    line_template = "{:>13s} {:g}"
    for name, value in loss_vals.items():
        label = name + " loss:"
        print(line_template.format(label, value))


def stylize(
    network,
    initial,
    initial_noiseblend,
    content,
    styles,
    preserve_colors,
    iterations,
    content_weight,
    content_weight_blend,
    style_weight,
    style_layer_weight_exp,
    style_blend_weights,
    tv_weight,
    learning_rate,
    beta1,
    beta2,
    epsilon,
    pooling,
    print_iterations=None,
    checkpoint_iterations=None,
):
    """
    Stylize images.

    This generator runs `iterations` Adam steps on an image variable so that
    it minimizes a weighted sum of a content loss, a style loss and a total
    variation loss, and yields a tuple ``(iteration, image, loss_vals)`` after
    every step.

    `image` and `loss_vals` are None by default. `image` is not None on the
    last step and on every step where ``i % checkpoint_iterations == 0``.
    `loss_vals` is not None on the last step and on every step where
    ``i % print_iterations == 0``.

    `iteration` is the 0-based step index ``i`` for every step except the
    last one, where ``i + 1`` (i.e. `iterations`) is yielded instead.

    `image`, when present, is built from the lowest-loss image seen among the
    steps at which the loss was checked (checkpoint steps and the last step),
    not necessarily from the current step.

    `loss_vals` is an OrderedDict with loss values for the current iteration,
    e.g. ``{'content': 1.23, 'style': 4.56, 'tv': 7.89, 'total': 13.68}``.

    :param network: passed to ``vgg.load_net`` (presumably a path to the VGG
        weights file -- confirm against the ``vgg`` module)
    :param initial: image array used as the starting point, or None to start
        from random normal noise
    :param initial_noiseblend: fraction of noise mixed into `initial`; the
        start image is ``initial * (1 - blend) + noise * blend``
    :param content: content image array; its shape (with a leading batch
        dimension of 1) is the shape of the optimized image
    :param styles: sequence of style image arrays
    :param preserve_colors: if truthy, yielded images keep the chroma of
        `content` and take only luminance from the stylized result
    :param iterations: number of optimizer steps
    :param content_weight: multiplier of the content loss
    :param content_weight_blend: weight of the "relu4_2" content layer;
        "relu5_2" gets ``1 - content_weight_blend``
    :param style_weight: multiplier of the style loss
    :param style_layer_weight_exp: factor by which the weight of each
        successive layer in STYLE_LAYERS is multiplied (weights are then
        normalized to sum to 1)
    :param style_blend_weights: per-style weights, indexed like `styles`
    :param tv_weight: multiplier of the total variation loss
    :param learning_rate: passed to ``tf.train.AdamOptimizer``
    :param beta1: passed to ``tf.train.AdamOptimizer``
    :param beta2: passed to ``tf.train.AdamOptimizer``
    :param epsilon: passed to ``tf.train.AdamOptimizer``
    :param pooling: passed to ``vgg.net_preloaded``
    :param print_iterations: loss reporting period in steps; falsy disables
        periodic reporting (the last step is always reported)
    :param checkpoint_iterations: image output period in steps; falsy
        disables periodic output (the last step always yields an image)
    :rtype: iterator[tuple[int, image or None, OrderedDict or None]]
    """
    # Prepend a batch dimension of 1 to every image shape.
    shape = (1,) + content.shape
    style_shapes = [(1,) + style.shape for style in styles]
    content_features = {}
    style_features = [{} for _ in styles]

    vgg_weights, vgg_mean_pixel = vgg.load_net(network)

    # Geometric progression of per-layer style weights: 1, e, e^2, ...
    layer_weight = 1.0
    style_layers_weights = {}
    for style_layer in STYLE_LAYERS:
        style_layers_weights[style_layer] = layer_weight
        layer_weight *= style_layer_weight_exp

    # normalize style layer weights
    layer_weights_sum = 0
    for style_layer in STYLE_LAYERS:
        layer_weights_sum += style_layers_weights[style_layer]
    for style_layer in STYLE_LAYERS:
        style_layers_weights[style_layer] /= layer_weights_sum

    # compute content features in feedforward mode
    # NOTE: `sess` is never referenced by name; the `.eval()` calls rely on the
    # session being active as the default session inside the `with` block.
    g = tf.Graph()
    with g.as_default(), g.device("/cpu:0"), tf.Session() as sess:
        image = tf.placeholder("float", shape=shape)
        net = vgg.net_preloaded(vgg_weights, image, pooling)
        content_pre = np.array([vgg.preprocess(content, vgg_mean_pixel)])
        for layer in CONTENT_LAYERS:
            content_features[layer] = net[layer].eval(feed_dict={image: content_pre})

    # compute style features in feedforward mode
    # A separate graph is built per style image because each style may have a
    # different shape.
    for i in range(len(styles)):
        g = tf.Graph()
        with g.as_default(), g.device("/cpu:0"), tf.Session() as sess:
            image = tf.placeholder("float", shape=style_shapes[i])
            net = vgg.net_preloaded(vgg_weights, image, pooling)
            style_pre = np.array([vgg.preprocess(styles[i], vgg_mean_pixel)])
            for layer in STYLE_LAYERS:
                features = net[layer].eval(feed_dict={image: style_pre})
                # Flatten all leading axes so each row is one position and each
                # column one channel (the size of axis 3).
                features = np.reshape(features, (-1, features.shape[3]))
                # Gram matrix of the channels, scaled by the total element count.
                gram = np.matmul(features.T, features) / features.size
                style_features[i][layer] = gram

    initial_content_noise_coeff = 1.0 - initial_noiseblend

    # make stylized image using backpropagation
    with tf.Graph().as_default():
        if initial is None:
            initial = tf.random_normal(shape) * 0.256
        else:
            initial = np.array([vgg.preprocess(initial, vgg_mean_pixel)])
            initial = initial.astype("float32")
            # Linear blend: (1 - noiseblend) * initial + noiseblend * noise.
            initial = (initial) * initial_content_noise_coeff + (tf.random_normal(shape) * 0.256) * (
                1.0 - initial_content_noise_coeff
            )
        # The image itself is the only trainable variable created here.
        image = tf.Variable(initial)
        net = vgg.net_preloaded(vgg_weights, image, pooling)

        # content loss
        # NOTE: the keys below are hard-coded and must match CONTENT_LAYERS.
        content_layers_weights = {}
        content_layers_weights["relu4_2"] = content_weight_blend
        content_layers_weights["relu5_2"] = 1.0 - content_weight_blend

        content_loss = 0
        content_losses = []
        for content_layer in CONTENT_LAYERS:
            # tf.nn.l2_loss returns half the sum of squares, hence the factor 2;
            # dividing by the element count gives a mean squared difference.
            content_losses.append(
                content_layers_weights[content_layer]
                * content_weight
                * (
                    2
                    * tf.nn.l2_loss(net[content_layer] - content_features[content_layer])
                    / content_features[content_layer].size
                )
            )
        content_loss += reduce(tf.add, content_losses)

        # style loss
        style_loss = 0
        for i in range(len(styles)):
            style_losses = []
            for style_layer in STYLE_LAYERS:
                layer = net[style_layer]
                # The lambda's parameter `i` shadows the loop index only inside
                # the lambda; the outer `i` is unaffected.
                _, height, width, number = map(lambda i: i.value, layer.get_shape())
                size = height * width * number
                # Same Gram construction as for the precomputed style features,
                # but as a graph op on the image being optimized.
                feats = tf.reshape(layer, (-1, number))
                gram = tf.matmul(tf.transpose(feats), feats) / size
                style_gram = style_features[i][style_layer]
                style_losses.append(
                    style_layers_weights[style_layer] * 2 * tf.nn.l2_loss(gram - style_gram) / style_gram.size
                )
            style_loss += style_weight * style_blend_weights[i] * reduce(tf.add, style_losses)

        # total variation denoising
        # Mean squared difference between neighbouring elements along axis 1
        # and along axis 2 of the image variable.
        tv_y_size = _tensor_size(image[:, 1:, :, :])
        tv_x_size = _tensor_size(image[:, :, 1:, :])
        tv_loss = (
            tv_weight
            * 2
            * (
                (tf.nn.l2_loss(image[:, 1:, :, :] - image[:, : shape[1] - 1, :, :]) / tv_y_size)
                + (tf.nn.l2_loss(image[:, :, 1:, :] - image[:, :, : shape[2] - 1, :]) / tv_x_size)
            )
        )

        # total loss
        loss = content_loss + style_loss + tv_loss

        # We use OrderedDict to make sure we have the same order of loss types
        # (content, style, tv, total) as defined by the initial construction of
        # the loss_store dict. This is important for print_progress() and
        # saving loss_arrs (column order) in the main script.
        #
        # Subtle Gotcha (tested with Python 3.5): The syntax
        # OrderedDict(key1=val1, key2=val2, ...) does /not/ create the same
        # order since, apparently, it first creates a normal dict with random
        # order (< Python 3.7) and then wraps that in an OrderedDict. We have
        # to pass in a data structure which is already ordered. I'd call this a
        # bug, since both constructor syntax variants result in different
        # objects. In 3.6, the order is preserved in dict() in CPython, in 3.7
        # they finally made it part of the language spec. Thank you!
        loss_store = OrderedDict([("content", content_loss), ("style", style_loss), ("tv", tv_loss), ("total", loss)])

        # optimizer setup
        train_step = tf.train.AdamOptimizer(learning_rate, beta1, beta2, epsilon).minimize(loss)

        # optimization
        # `best` / `best_loss` track the lowest total loss among the steps at
        # which the loss is checked below.
        best_loss = float("inf")
        best = None
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print("Optimization started...")
            # Report the losses of the initial image before the first step.
            if print_iterations and print_iterations != 0:
                print_progress(get_loss_vals(loss_store))
            iteration_times = []
            start = time.time()
            for i in range(iterations):
                iteration_start = time.time()
                if i > 0:
                    elapsed = time.time() - start
                    # take average of last couple steps to get time per iteration
                    remaining = np.mean(iteration_times[-10:]) * (iterations - i)
                    print(f"Iteration {i + 1:4d}/{iterations:4d} ({hms(elapsed)} elapsed, {hms(remaining)} remaining)")
                else:
                    print(f"Iteration {i + 1:4d}/{iterations:4d}")
                train_step.run()

                last_step = i == iterations - 1
                if last_step or (print_iterations and i % print_iterations == 0):
                    loss_vals = get_loss_vals(loss_store)
                    print_progress(loss_vals)
                else:
                    loss_vals = None

                if (checkpoint_iterations and i % checkpoint_iterations == 0) or last_step:
                    this_loss = loss.eval()
                    # NOTE(review): if the loss is NaN the first time this runs,
                    # the comparison is False, `best` stays None and the
                    # reshape below raises AttributeError.
                    if this_loss < best_loss:
                        best_loss = this_loss
                        best = image.eval()

                    # Drop the batch dimension and undo the VGG preprocessing.
                    img_out = vgg.unprocess(best.reshape(shape[1:]), vgg_mean_pixel)

                    if preserve_colors:
                        original_image = np.clip(content, 0, 255)
                        styled_image = np.clip(img_out, 0, 255)

                        # Luminosity transfer steps:
                        # 1. Convert stylized RGB->grayscale according to Rec.601 luma (0.299, 0.587, 0.114)
                        # 2. Convert stylized grayscale into YUV (YCbCr)
                        # 3. Convert original image into YUV (YCbCr)
                        # 4. Recombine (stylizedYUV.Y, originalYUV.U, originalYUV.V)
                        # 5. Convert recombined image from YUV back to RGB

                        # 1
                        styled_grayscale = rgb2gray(styled_image)
                        styled_grayscale_rgb = gray2rgb(styled_grayscale)

                        # 2
                        styled_grayscale_yuv = np.array(
                            Image.fromarray(styled_grayscale_rgb.astype(np.uint8)).convert("YCbCr")
                        )

                        # 3
                        original_yuv = np.array(Image.fromarray(original_image.astype(np.uint8)).convert("YCbCr"))

                        # 4
                        # `w` and `h` are simply the first two dimensions of the
                        # content array, whatever their names suggest.
                        w, h, _ = original_image.shape
                        combined_yuv = np.empty((w, h, 3), dtype=np.uint8)
                        combined_yuv[..., 0] = styled_grayscale_yuv[..., 0]
                        combined_yuv[..., 1] = original_yuv[..., 1]
                        combined_yuv[..., 2] = original_yuv[..., 2]

                        # 5
                        img_out = np.array(Image.fromarray(combined_yuv, "YCbCr").convert("RGB"))
                else:
                    img_out = None

                # The first tuple element is `i + 1` on the last step and `i`
                # otherwise (the conditional binds tighter than the commas).
                yield i + 1 if last_step else i, img_out, loss_vals

                # Recorded after the generator is resumed, so the measured time
                # includes whatever the consumer did with the yielded values.
                iteration_end = time.time()
                iteration_times.append(iteration_end - iteration_start)


def _tensor_size(tensor):
    """Return the product of the static dimensions of *tensor*.

    Reads ``.value`` from every entry of ``tensor.get_shape()`` and multiplies
    them together, starting from 1 (so an empty shape yields 1).

    :param tensor: object whose ``get_shape()`` yields items with ``.value``
    :return: product of all dimension values
    """
    element_count = 1
    for dim in tensor.get_shape():
        element_count = element_count * dim.value
    return element_count


def rgb2gray(rgb):
    """Collapse the last axis of *rgb* into a single weighted-sum channel.

    Only the first three entries of the last axis are used (any further
    channels are ignored); they are weighted by 0.299, 0.587 and 0.114.

    :param rgb: array whose last axis holds at least three channels
    :return: array with the last axis removed
    """
    channel_weights = [0.299, 0.587, 0.114]
    color_channels = rgb[..., :3]
    return np.dot(color_channels, channel_weights)


def gray2rgb(gray):
    """Replicate a 2-D array into three identical float32 channels.

    :param gray: 2-D array (unpacking its shape fails for any other rank)
    :return: float32 array of shape ``gray.shape + (3,)`` in which every
        channel is a copy of *gray*
    """
    rows, cols = gray.shape
    stacked = np.empty((rows, cols, 3), dtype=np.float32)
    # Broadcast the single plane across the trailing channel axis.
    stacked[...] = gray[:, :, np.newaxis]
    return stacked


def hms(seconds):
    """Format a duration in seconds as a short human-readable string.

    The input is truncated with ``int()`` first. Only the two most
    significant units are shown: hours and minutes when at least one hour,
    minutes and seconds when at least one minute, otherwise seconds alone.

    :param seconds: duration in seconds (anything ``int()`` accepts)
    :return: e.g. ``"2 hr 5 min"``, ``"3 min 12 sec"`` or ``"42 sec"``
    """
    minutes, seconds = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    if hours > 0:
        return f"{hours:d} hr {minutes:d} min"
    if minutes > 0:
        return f"{minutes:d} min {seconds:d} sec"
    return f"{seconds:d} sec"
